import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import pickle
import time
import os
from tqdm import tqdm
from Utils import *
from cluster_clc import cluster_clc, cluster_only


class Trainer(object):
    """
    """
    def __init__(self, args, model, voc=None):
        super(Trainer, self).__init__()
        self.model = model.to(args.device)
        self.epoch = args.epoch
        self.data_name = args.dataname
        self.device = args.device
        self.topic_k = args.K
        self.test_every = args.test_every
        self.train_num = -1
        self.clc_num = args.clc_num
        self.layer_num = len(self.topic_k)

        log_str = 'runs/{}/k_{}'.format(args.dataname, self.topic_k)
        now = int(round(time.time() * 1000))
        now_time = time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime(now / 1000))
        self.log_str = log_str + '/' + now_time
        if not os.path.exists(self.log_str):
            os.makedirs(self.log_str)

        self.trainable_params = []
        print('WeTe learnable params:')
        word_params = []
        for name, params in self.model.named_parameters():
            if params.requires_grad:
                print(name)
                self.trainable_params.append(params)
            else:
                print('errorrrrrrrrrrrrrrr', name)
        self.optimizer = torch.optim.AdamW(self.trainable_params, lr=args.lr, weight_decay=1e-3)
        # self.optimizer1 = torch.optim.AdamW(self.trainable_params + word_params, lr=args.lr, weight_decay=1e-3)

    def train(self, train_loader, test_loader):
        """Run upward layer-wise pretraining followed by the main training.

        Phase 1 (pretraining): for every layer index ``l`` in
        ``range(self.model.layer_num)`` the loader is iterated ``inner_epoch``
        times while optimizing the losses of layers ``0..l``. After each pass
        the latest ``phi`` is written out through ``vision_phi``.

        Phase 2: ``self.epoch`` epochs of "downward refining" when
        ``self.layer_num > 1``, otherwise plain single-layer training.
        ``self.test`` is called whenever ``epoch % self.test_every == 0``.

        Args:
            train_loader: sized iterable yielding ``(bow, tfidf, label)``
                batches; ``label`` is ignored during training.
            test_loader: forwarded unchanged to ``self.test``.
        """
        inner_epoch = 2
        self._pretrain_upward(train_loader, inner_epoch)
        if self.layer_num > 1:
            self._refine_downward(train_loader, test_loader)
        else:
            self._train_single_layer(train_loader, test_loader)

    def _apply_gradients(self, loss):
        """Backpropagate ``loss`` and take one optimizer step.

        Non-finite gradient entries (NaN, +inf, -inf) are replaced by 0 and
        each parameter's gradient is then clipped on its own to an L2 norm of
        20. Parameters that received no gradient in this step are skipped.
        """
        self.optimizer.zero_grad()
        loss.backward()
        for p in self.trainable_params:
            grad = p.grad
            if grad is None:
                # Parameter did not take part in this loss (e.g. upper layers
                # during early pretraining); nothing to sanitize.
                continue
            p.grad = torch.where(torch.isfinite(grad), grad, torch.zeros_like(grad))
            nn.utils.clip_grad_norm_(p, max_norm=20, norm_type=2)
        self.optimizer.step()

    def _upward_losses(self, layer_id, tfidf, theta, phi):
        """Return ``(tm_loss, forward_cost, backward_cost)`` for layers 0..layer_id.

        Layer 0 is always fitted against ``tfidf``. Every higher layer ``k``
        is fitted against the detached ``theta[k - 1]`` and detached
        ``alpha[k - 1]``, so gradients of those terms do not reach the layer
        below. For ``layer_id > 0`` the ``Entropy`` terms are averaged and the
        ``GCT`` costs are summed over the layers.
        """
        rec_x = torch.matmul(theta[0], phi[0].t())
        base_tm = self.model.Entropy(tfidf, rec_x)
        forward_cost, backward_cost = self.model.GCT(tfidf, theta[0], self.model.rho.detach(),
                                                     self.model.alpha[0])
        if layer_id == 0:
            return base_tm, forward_cost, backward_cost

        tm_terms = [base_tm]
        forward_terms = [forward_cost]
        backward_terms = [backward_cost]
        for lower in range(layer_id):
            upper = lower + 1
            rec_x = torch.matmul(theta[upper], phi[upper].t())
            tm_terms.append(self.model.Entropy(theta[lower].detach(), rec_x))
            forward_cost, backward_cost = self.model.GCT(theta[lower].detach(), theta[upper],
                                                         self.model.alpha[lower].detach(),
                                                         self.model.alpha[upper])
            forward_terms.append(forward_cost)
            backward_terms.append(backward_cost)
        return (torch.mean(torch.stack(tm_terms)),
                torch.sum(torch.stack(forward_terms)),
                torch.sum(torch.stack(backward_terms)))

    def _downward_losses(self, tfidf, theta, phi):
        """Return ``(tm_loss, layer_loss)`` for one downward-refining batch.

        Each layer is tied to its neighbours through ``self.model.GCT`` with
        the neighbour's ``theta``/``alpha`` detached. Interior layers weight
        their four cost terms by 0.5; the bottom and top layers use weight 1.
        ``tm_loss`` is the mean of the per-layer ``Entropy`` terms and
        ``layer_loss`` the sum of the per-layer ``GCT`` costs.
        """
        tm_terms = []
        layer_terms = []

        # Interior layers, visited top-down (e.g. 2, 1 for a 4-layer model).
        for layer_id in range(self.model.layer_num - 2, 0, -1):
            forward_up, backward_up = self.model.GCT(theta[layer_id], theta[layer_id + 1].detach(),
                                                     self.model.alpha[layer_id],
                                                     self.model.alpha[layer_id + 1].detach())
            forward_down, backward_down = self.model.GCT(theta[layer_id], theta[layer_id - 1].detach(),
                                                         self.model.alpha[layer_id],
                                                         self.model.alpha[layer_id - 1].detach())
            rec_x = torch.matmul(theta[layer_id], phi[layer_id].t())
            tm_terms.append(self.model.Entropy(theta[layer_id - 1].detach(), rec_x))
            layer_terms.append(0.5 * forward_up + 0.5 * backward_up
                               + 0.5 * forward_down + 0.5 * backward_down)

        # Bottom layer: linked upward to layer 1 and downward to the data.
        forward_up, backward_up = self.model.GCT(theta[0], theta[1].detach(),
                                                 self.model.alpha[0],
                                                 self.model.alpha[1].detach())
        forward_down, backward_down = self.model.GCT(theta[0], tfidf,
                                                     self.model.alpha[0],
                                                     self.model.rho.detach())
        layer_terms.append(forward_up + backward_up + forward_down + backward_down)

        # Top layer: only has a neighbour below it.
        forward_down, backward_down = self.model.GCT(theta[-1], theta[-2].detach(),
                                                     self.model.alpha[-1],
                                                     self.model.alpha[-2].detach())
        layer_terms.append(forward_down + backward_down)

        re_x = torch.matmul(theta[0], phi[0].t())
        tm_terms.append(self.model.Entropy(tfidf, re_x))
        re_x = torch.matmul(theta[-1], phi[-1].t())
        tm_terms.append(self.model.Entropy(theta[-2].detach(), re_x))

        return torch.mean(torch.stack(tm_terms)), torch.sum(torch.stack(layer_terms))

    def _pretrain_upward(self, train_loader, inner_epoch):
        """Pretrain layer by layer, running ``inner_epoch`` passes per layer."""
        print(f'upward layer-wise pretraining start')
        epoch = 0  # single pretraining round; only appears in logs and paths
        # Running statistics span the whole pretraining phase (all layers).
        tr_loss = []
        tr_forward_cost = []
        tr_backward_cost = []
        tr_tm = []
        self.model.train()
        print(f'layer-wise training')
        for layer_id in range(self.model.layer_num):
            for i in range(inner_epoch):
                phi = None
                pbar = tqdm(enumerate(train_loader), total=len(train_loader))
                for _, (bow, tfidf, _label) in pbar:
                    bow = bow.to(self.device).float()
                    tfidf = tfidf.to(self.device).float()
                    self.train_num += 1

                    theta = self.model.InferNet(bow)
                    self.model.update_embeddings()
                    phi = self.model.cal_phi()
                    tm_loss, forward_cost, backward_cost = self._upward_losses(layer_id, tfidf, theta, phi)
                    loss = 1.0 * tm_loss + forward_cost + backward_cost
                    self._apply_gradients(loss)

                    tr_loss.append(loss.item())
                    tr_forward_cost.append(forward_cost.item())
                    tr_backward_cost.append(backward_cost.item())
                    tr_tm.append(tm_loss.item())
                    pbar.set_description(
                        f'layer_id: {layer_id}, epoch: {epoch}|{self.epoch}|{i}|{inner_epoch}, loss: {np.mean(tr_loss):.4f},  forword_cost: {np.mean(tr_forward_cost):.4f},  '
                        f'backward_cost: {np.mean(tr_backward_cost):.4f}, TM_loss: {np.mean(tr_tm):.4f}')

                if phi is not None:
                    # Dump the phi computed for the last batch of this pass;
                    # skipped when the loader produced no batch at all.
                    print(f'save voc')
                    phi_cpu = [each.cpu().detach().numpy() for each in phi]
                    vision_phi(phi_cpu, outpath=f'{self.log_str}/pre_{epoch}/{layer_id}/{i}',
                               voc=self.model.voc)

    def _refine_downward(self, train_loader, test_loader):
        """Jointly refine all layers for ``self.epoch`` epochs (multi-layer case)."""
        print(f'downward refining')
        for epoch in range(self.epoch):
            # test() switches the model to eval mode, so re-enable training.
            self.model.train()
            pbar = tqdm(enumerate(train_loader), total=len(train_loader))
            tm_loss_tr = []
            ct_loss_tr = []
            tr_loss = []
            for _, (bow, tfidf, _label) in pbar:
                bow = bow.to(self.device).float()
                tfidf = tfidf.to(self.device).float()
                self.train_num += 1

                theta = self.model.InferNet(bow)
                self.model.update_embeddings()
                phi = self.model.cal_phi()
                tm_loss, layer_loss = self._downward_losses(tfidf, theta, phi)
                loss = 1.0 * tm_loss + layer_loss
                self._apply_gradients(loss)

                tm_loss_tr.append(tm_loss.item())
                ct_loss_tr.append(layer_loss.item())
                tr_loss.append(loss.item())
                pbar.set_description(
                    f'epoch: {epoch}|{self.epoch}, loss: {np.mean(tr_loss):.4f},  ct loss: {np.mean(ct_loss_tr):.4f},  '
                    f'TM_loss: {np.mean(tm_loss_tr):.4f}')

            if epoch % self.test_every == 0:
                self.test(epoch, test_loader)

    def _train_single_layer(self, train_loader, test_loader):
        """Train a one-layer model for ``self.epoch`` epochs."""
        layer_id = 0
        for epoch in range(self.epoch):
            # test() switches the model to eval mode, so re-enable training.
            self.model.train()
            # Per-epoch statistics, independent of the pretraining phase.
            tr_loss = []
            tr_forward_cost = []
            tr_backward_cost = []
            tr_tm = []
            pbar = tqdm(enumerate(train_loader), total=len(train_loader))
            for _, (bow, tfidf, _label) in pbar:
                bow = bow.to(self.device).float()
                tfidf = tfidf.to(self.device).float()
                self.train_num += 1

                theta = self.model.InferNet(bow)
                self.model.update_embeddings()
                phi = self.model.cal_phi()
                tm_loss, forward_cost, backward_cost = self._upward_losses(layer_id, tfidf, theta, phi)
                loss = 1.0 * tm_loss + forward_cost + backward_cost
                self._apply_gradients(loss)

                tr_loss.append(loss.item())
                tr_forward_cost.append(forward_cost.item())
                tr_backward_cost.append(backward_cost.item())
                tr_tm.append(tm_loss.item())
                pbar.set_description(
                    f'layer_id: {layer_id}, epoch: {epoch}|{self.epoch}, loss: {np.mean(tr_loss):.4f},  forword_cost: {np.mean(tr_forward_cost):.4f},  '
                    f'backward_cost: {np.mean(tr_backward_cost):.4f}, TM_loss: {np.mean(tr_tm):.4f}')

            if epoch % self.test_every == 0:
                self.test(epoch, test_loader)


    def test(self, epoch, test_loader):
        """Save the current topics/embeddings and print clustering scores.

        The model is put in eval mode (callers must switch back to train
        mode). Under ``torch.no_grad()`` the method:

        1. refreshes the embeddings, computes ``phi`` and writes it through
           ``vision_phi`` to ``<log_str>/<epoch>_refine``;
        2. saves the embeddings to ``<log_str>/<epoch>_refine/embedding.pkl``;
        3. collects ``self.model.InferNet(bow)[0]`` and the labels for every
           batch of ``test_loader``.

        The collected arrays are passed to ``cluster_only`` together with
        ``self.clc_num`` and the returned purity / NMI values are printed.

        Args:
            epoch: current epoch index; used in output paths and the summary.
            test_loader: iterable yielding ``(bow, tfidf, label)`` batches;
                ``tfidf`` is not used here. ``label`` is converted with
                ``.detach().numpy()`` — presumably a CPU tensor; confirm
                against the dataset.
        """
        self.model.eval()
        theta_batches = []
        label_batches = []
        with torch.no_grad():
            ### visualize topics and save embeddings
            self.model.update_embeddings()
            phi = self.model.cal_phi()
            phi_cpu = [each.cpu().detach().numpy() for each in phi]
            vision_phi(phi_cpu, outpath=f'{self.log_str}/{epoch}_refine',
                       voc=self.model.voc)

            self.model.save_embeddings(f'{self.log_str}/{epoch}_refine/embedding.pkl')

            for bow, _tfidf, label in test_loader:
                bow = bow.to(self.device).float()
                # Only the first entry of InferNet's output is clustered.
                theta = self.model.InferNet(bow)[0]
                theta_batches.append(theta.cpu().detach().numpy())
                label_batches.append(label.detach().numpy())

        # Concatenate once instead of re-copying the growing array per batch.
        # An empty loader yields None for both, as it did before.
        test_theta = np.concatenate(theta_batches) if theta_batches else None
        test_label = np.concatenate(label_batches) if label_batches else None

        purity_value, nmi_value = cluster_only(test_theta, test_label, self.clc_num)
        print(f'*************************** Test Summary **************************')
        print(f'Epoch {epoch}|{self.epoch}\n'
              f'Clustering, purity: {purity_value:.4f}, nmi: {nmi_value:.4f}\n')




    def train_image(self, train_loader):
        inner_epoch = 3
        print(f'upward layer-wise pretraining start')
        for epoch in range(1):
            tr_loss = []
            tr_forward_cost = []
            tr_backward_cost = []
            tr_tm = []
            # pbar = tqdm(enumerate(train_loader), total=len(train_loader))
            self.model.train()
            print(f'layer-wise training')
            for layer_id in range(self.model.layer_num):

                if layer_id == 0:
                    for i in range(inner_epoch):
                        pbar = tqdm(enumerate(train_loader), total=len(train_loader))
                        for j, tfidf in pbar:
                            # bow_norm = bow / torch.sum(bow, dim=1, keepdim=True)
                            # bow = bow.to(self.device).float()
                            self.train_num += 1
                            # tfidf = tfidf.to(self.device).float()
                            tfidf = tfidf.to(self.device).float()
                            # tfidf = bow

                            ### layer-wise training
                            theta = self.model.InferNet(tfidf)
                            self.model.update_embeddings()
                            phi = self.model.cal_phi()
                            rec_x = torch.matmul(theta[0], phi[0].t())
                            # tm_loss = self.model.Poisson_likelihood(bow, rec_x)
                            tm_loss = self.model.Entropy(tfidf, rec_x)
                            # tm_loss = self.model.Entropy_v1(bow, rec_x)

                            forward_cost, backward_cost = self.model.GCT(tfidf, theta[layer_id], self.model.rho.detach(),
                                                                         self.model.alpha[layer_id])
                            loss = 1.0 * tm_loss + forward_cost + backward_cost

                            self.optimizer.zero_grad()
                            loss.backward()
                            for p in self.trainable_params:
                                try:
                                    p.grad = p.grad.where(~torch.isnan(p.grad), torch.tensor(0., device=p.grad.device))
                                    p.grad = p.grad.where(~torch.isinf(p.grad), torch.tensor(0., device=p.grad.device))
                                    nn.utils.clip_grad_norm_(p, max_norm=20, norm_type=2)
                                except:
                                    pass
                            self.optimizer.step()
                            tr_loss.append(loss.item())
                            tr_forward_cost.append(forward_cost.item())
                            tr_backward_cost.append(backward_cost.item())
                            tr_tm.append(tm_loss.item())
                            pbar.set_description(
                                f'layer_id: {layer_id}, epoch: {epoch}|{self.epoch}|{i}|{inner_epoch}, loss: {np.mean(tr_loss):.4f},  forword_cost: {np.mean(tr_forward_cost):.4f},  '
                                f'backward_cost: {np.mean(tr_backward_cost):.4f}, TM_loss: {np.mean(tr_tm):.4f}')

                        if (i+1) % 1 == 0:
                            print(f'save voc')
                            phi_cpu = [each.cpu().detach().numpy() for each in phi]
                            vision_phi(phi_cpu, outpath=f'{self.log_str}/pre_{epoch}/{layer_id}/{i}',
                                       voc=self.model.voc)
                else:
                    # if True and (epoch == 0):
                    #     print(f'\n init topic embedding via previous layer \n')
                    #     topic_e = self.model.alpha[layer_id-1].cpu().detach().numpy()
                    #     topic_init = cluster_kmeans(topic_e, self.topic_k[layer_id])
                    #     self.model.topic_layer[layer_id] = self.model.topic_layer[layer_id].from_pretrained(torch.from_numpy(topic_init).float(), freeze=False).to(self.device)

                    for i in range(inner_epoch):
                        pbar = tqdm(enumerate(train_loader), total=len(train_loader))
                        for j, tfidf in pbar:
                            # bow_norm = bow / torch.sum(bow, dim=1, keepdim=True)
                            # bow = bow.to(self.device).float()
                            self.train_num += 1
                            # tfidf = tfidf.to(self.device).float()
                            tfidf = tfidf.to(self.device).float()
                            # tfidf = bow

                            ### layer-wise training
                            theta = self.model.InferNet(tfidf)
                            self.model.update_embeddings()
                            phi = self.model.cal_phi()
                            # phi_layer = self.model.cal_phi_layer(layer_id)
                            rec_x = torch.matmul(theta[0], phi[0].t())
                            # tm_loss = [self.model.Poisson_likelihood(bow, rec_x)]
                            tm_loss = [self.model.Entropy(tfidf, rec_x)]
                            # tm_loss = self.model.Entropy_v1(bow, rec_x)


                            forward_c_list = []
                            backward_c_list = []
                            forward_cost, backward_cost = self.model.GCT(tfidf, theta[0], self.model.rho.detach(),
                                                                         self.model.alpha[0])
                            forward_c_list.append(forward_cost)
                            backward_c_list.append(backward_cost)
                            for inner_layer_id in range(layer_id):
                                # phi_layer = self.model.cal_phi_layer(inner_layer_id+1)
                                # rec_x = torch.matmul(theta[inner_layer_id+1], phi_layer.t())
                                # tm_loss.append(self.model.Poisson_likelihood(bow, rec_x))

                                # phi_layer = self.model.cal_phi_layer(inner_layer_id + 1)
                                rec_x = torch.matmul(theta[inner_layer_id + 1], phi[inner_layer_id+1].t())
                                tm_loss.append(self.model.Entropy(theta[inner_layer_id].detach(), rec_x))

                                forward_cost, backward_cost = self.model.GCT(theta[inner_layer_id].detach(), theta[inner_layer_id+1], self.model.alpha[inner_layer_id].detach(),
                                                                             self.model.alpha[inner_layer_id + 1])
                                forward_c_list.append(forward_cost)
                                backward_c_list.append(backward_cost)
                            forward_cost = torch.sum(torch.stack(forward_c_list))
                            backward_cost = torch.sum(torch.stack(backward_c_list))
                            tm_loss = torch.mean(torch.stack(tm_loss))
                            loss =1.0 * tm_loss + forward_cost + backward_cost

                            self.optimizer.zero_grad()
                            loss.backward()
                            for p in self.trainable_params:
                                try:
                                    p.grad = p.grad.where(~torch.isnan(p.grad), torch.tensor(0., device=p.grad.device))
                                    p.grad = p.grad.where(~torch.isinf(p.grad), torch.tensor(0., device=p.grad.device))
                                    nn.utils.clip_grad_norm_(p, max_norm=20, norm_type=2)
                                except:
                                    pass
                            self.optimizer.step()
                            tr_loss.append(loss.item())
                            tr_forward_cost.append(forward_cost.item())
                            tr_backward_cost.append(backward_cost.item())
                            tr_tm.append(tm_loss.item())
                            pbar.set_description(
                                f'layer_id: {layer_id}, epoch: {epoch}|{self.epoch}|{i}|{inner_epoch}, loss: {np.mean(tr_loss):.4f},  forword_cost: {np.mean(tr_forward_cost):.4f},  '
                                f'backward_cost: {np.mean(tr_backward_cost):.4f}, TM_loss: {np.mean(tr_tm):.4f}')
                        if (i+1) % 1 ==0:
                            print(f'save voc')
                            phi_cpu = [each.cpu().detach().numpy() for each in phi]
                            vision_phi(phi_cpu, outpath=f'{self.log_str}/pre_{epoch}/{layer_id}/{i}',
                                       voc=self.model.voc)
        if self.layer_num > 1:
            print(f'downward refining')
            for epoch in range(self.epoch):
                self.model.train()
                pbar = tqdm(enumerate(train_loader), total=len(train_loader))
                tm_loss_tr = []
                ct_loss_tr = []
                tr_loss = []
                for j, tfidf in pbar:
                    # bow_norm = bow / torch.sum(bow, dim=1, keepdim=True)
                    # bow = bow.to(self.device).float()
                    self.train_num += 1
                    # tfidf = tfidf.to(self.device).float()
                    tfidf = tfidf.to(self.device).float()
                    # tfidf = bow

                    ### layer-wise training
                    theta = self.model.InferNet(tfidf)
                    self.model.update_embeddings()
                    phi = self.model.cal_phi()
                    # rec_x = torch.matmul(theta[0], phi[0].t())
                    # tm_loss = self.model.Poisson_likelihood(bow, rec_x)
                    tm_loss = []
                    layer_loss = []

                    for layer_id in range(self.model.layer_num-2, 0, -1):   ### for 4 layer model: 2,1
                        forward_cost_up, backward_cost_up = self.model.GCT(theta[layer_id], theta[layer_id + 1].detach(),
                                                                     self.model.alpha[layer_id],
                                                                     self.model.alpha[layer_id + 1].detach())
                        forward_cost_down, backward_cost_down = self.model.GCT(theta[layer_id], theta[layer_id -1].detach(),
                                                                           self.model.alpha[layer_id],
                                                                           self.model.alpha[layer_id -1].detach())
                        each_loss = 0.5 * forward_cost_up + 0.5 * backward_cost_up + 0.5 * forward_cost_down + 0.5 * backward_cost_down
                        # phi_layer = self.model.cal_phi_layer(layer_id)
                        # re_x = torch.matmul(theta[layer_id], phi_layer.t())
                        # tm_loss.append(1.0 * self.model.Poisson_likelihood(bow, re_x))

                        rec_x = torch.matmul(theta[layer_id], phi[layer_id].t())
                        tm_loss.append(self.model.Entropy(theta[layer_id-1].detach(), rec_x))

                        layer_loss.append(each_loss)

                    forward_cost_up, backward_cost_up = self.model.GCT(theta[0], theta[1].detach(),
                                                                       self.model.alpha[0],
                                                                       self.model.alpha[1].detach())
                    forward_cost_down, backward_cost_down = self.model.GCT(theta[0], tfidf,
                                                                           self.model.alpha[0],
                                                                           self.model.rho.detach())
                    each_loss = forward_cost_up + backward_cost_up + forward_cost_down + backward_cost_down
                    layer_loss.append(each_loss)

                    forward_cost_down, backward_cost_down = self.model.GCT(theta[-1], theta[-2].detach(),
                                                                           self.model.alpha[-1],
                                                                           self.model.alpha[-2].detach())
                    each_loss = forward_cost_down + backward_cost_down
                    layer_loss.append(each_loss)
                    re_x = torch.matmul(theta[0], phi[0].t())
                    # tm_loss.append(1.0 * self.model.Poisson_likelihood(bow, re_x))
                    tm_loss.append(1.0 * self.model.Entropy(tfidf, re_x))
                    # tm_loss = self.model.Entropy_v1(bow, rec_x)

                    re_x = torch.matmul(theta[-1], phi[-1].t())
                    tm_loss.append(1.0 * self.model.Entropy(theta[-2].detach(), re_x))


                    tm_loss = torch.mean(torch.stack(tm_loss))
                    layer_loss = torch.sum(torch.stack(layer_loss))
                    loss = 1.0 * tm_loss + layer_loss
                    self.optimizer.zero_grad()
                    loss.backward()
                    for p in self.trainable_params:
                        try:
                            p.grad = p.grad.where(~torch.isnan(p.grad), torch.tensor(0., device=p.grad.device))
                            p.grad = p.grad.where(~torch.isinf(p.grad), torch.tensor(0., device=p.grad.device))
                            nn.utils.clip_grad_norm_(p, max_norm=20, norm_type=2)
                        except:
                            pass
                    self.optimizer.step()
                    tm_loss_tr.append(tm_loss.item())
                    ct_loss_tr.append(layer_loss.item())
                    tr_loss.append(loss.item())
                    pbar.set_description(
                        f'epoch: {epoch}|{self.epoch}, loss: {np.mean(tr_loss):.4f},  ct loss: {np.mean(ct_loss_tr):.4f},  '
                        f'TM_loss: {np.mean(tm_loss_tr):.4f}')
                print(f'save voc')
                phi_cpu = [each.cpu().detach().numpy() for each in phi]
                vision_phi(phi_cpu, outpath=f'{self.log_str}/pre_{epoch}/{layer_id}/{i}',
                           voc=self.model.voc)

        else:
            layer_id = 0
            for epoch in range(self.epoch):
                self.model.train()
                pbar = tqdm(enumerate(train_loader), total=len(train_loader))
                for j, tfidf in pbar:
                    # bow_norm = bow / torch.sum(bow, dim=1, keepdim=True)
                    # bow = bow.to(self.device).float()
                    self.train_num += 1
                    # tfidf = tfidf.to(self.device).float()
                    tfidf = tfidf.to(self.device).float()
                    # tfidf = bow

                    ### layer-wise training
                    theta = self.model.InferNet(tfidf)
                    self.model.update_embeddings()
                    phi = self.model.cal_phi()
                    rec_x = torch.matmul(theta[0], phi[0].t())
                    # tm_loss = self.model.Poisson_likelihood(bow, rec_x)
                    tm_loss = self.model.Entropy(tfidf, rec_x)
                    # tm_loss = self.model.Entropy_v1(bow, rec_x)

                    forward_cost, backward_cost = self.model.GCT(tfidf, theta[layer_id], self.model.rho.detach(),
                                                                 self.model.alpha[layer_id])
                    loss = 1.0 * tm_loss + forward_cost + backward_cost

                    self.optimizer.zero_grad()
                    loss.backward()
                    for p in self.trainable_params:
                        try:
                            p.grad = p.grad.where(~torch.isnan(p.grad), torch.tensor(0., device=p.grad.device))
                            p.grad = p.grad.where(~torch.isinf(p.grad), torch.tensor(0., device=p.grad.device))
                            nn.utils.clip_grad_norm_(p, max_norm=20, norm_type=2)
                        except:
                            pass
                    self.optimizer.step()
                    tr_loss.append(loss.item())
                    tr_forward_cost.append(forward_cost.item())
                    tr_backward_cost.append(backward_cost.item())
                    tr_tm.append(tm_loss.item())
                    pbar.set_description(
                        f'layer_id: {layer_id}, epoch: {epoch}|{self.epoch}|{i}|{inner_epoch}, loss: {np.mean(tr_loss):.4f},  forword_cost: {np.mean(tr_forward_cost):.4f},  '
                        f'backward_cost: {np.mean(tr_backward_cost):.4f}, TM_loss: {np.mean(tr_tm):.4f}')
